{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python361064bitpy36conda4f32697db52543acaedb8346d569b572",
   "display_name": "Python 3.6.10 64-bit ('py36': conda)"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
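  {
   "source": [
    "# End-to-end BERT training (Trainer) and prediction (pipeline) with the Hugging Face stack (transformers, datasets)\n",
    "\n",
    "Hugging Face's [transformers](https://github.com/huggingface/transformers) had 39.5k stars at the time of writing and is possibly the most popular deep learning library right now. The same organization also provides the [datasets](https://github.com/huggingface/datasets) library, which helps you fetch and process data quickly. Together, these tools make the whole machine learning workflow for BERT-style models simpler than ever.\n",
    "\n",
    "That said, I have not found a simple tutorial online that covers the whole stack end to end, so I wrote this one in the hope of helping more people get started quickly.\n",
    "\n",
    "Here we use the `AGNews` news classification task as an example to walk through the complete workflow."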
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"2\"  # I use GPU 2 here; adjust as needed\n",
    "import torch\n",
    "from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
    "from sklearn.metrics import accuracy_score, precision_recall_fscore_support\n",
    "from transformers import Trainer, TrainingArguments\n",
    "from transformers import pipeline\n",
    "from datasets import load_dataset"
   ]
  },
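  {
   "source": [
    "## Loading the dataset with datasets\n",
    "\n",
    "The code below takes the first 40,000 examples of the original train split as our training set and examples 40,000 to 50,000 as our development set (this subset is already enough to train a decent model, and it keeps training time short). The original test split serves as our test set."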
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "Using custom data configuration default\n",
      "Reusing dataset ag_news (/home/zhiling/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)\n",
      "Using custom data configuration default\n",
      "Reusing dataset ag_news (/home/zhiling/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)\n",
      "Using custom data configuration default\n",
      "Reusing dataset ag_news (/home/zhiling/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)\n",
      "Dataset({\n",
      "    features: ['text', 'label'],\n",
      "    num_rows: 40000\n",
      "})\n",
      "Dataset({\n",
      "    features: ['text', 'label'],\n",
      "    num_rows: 10000\n",
      "})\n",
      "Dataset({\n",
      "    features: ['text', 'label'],\n",
      "    num_rows: 7600\n",
      "})\n"
     ]
    }
   ],
   "source": [
    "train_dataset = load_dataset(\"ag_news\", split=\"train[:40000]\")\n",
    "dev_dataset = load_dataset(\"ag_news\", split=\"train[40000:50000]\")\n",
    "test_dataset = load_dataset(\"ag_news\", split=\"test\")\n",
    "print(train_dataset)\n",
    "print(dev_dataset)\n",
    "print(test_dataset)"
   ]
  },
  {
   "source": [
    "The original dataset contains two fields, `text` and `label`."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "{'text': Value(dtype='string', id=None),\n",
       " 'label': ClassLabel(num_classes=4, names=['World', 'Sports', 'Business', 'Sci/Tech'], names_file=None, id=None)}"
      ]
     },
     "metadata": {},
     "execution_count": 9
    }
   ],
   "source": [
    "train_dataset.features"
   ]
  },
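  {
   "source": [
    "The BERT model expects the label field to be named `labels`, whereas the original dataset calls it `label`, so we need a small adjustment.\n",
    "\n",
    "The code below copies the `label` field to `labels`."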
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "Loading cached processed dataset at /home/zhiling/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a/cache-fe78f98825d1041c.arrow\n"
     ]
    },
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "{'label': 2,\n",
       " 'labels': 2,\n",
       " 'text': \"Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\\\band of ultra-cynics, are seeing green again.\"}"
      ]
     },
     "metadata": {},
     "execution_count": 10
    }
   ],
   "source": [
    "train_dataset = train_dataset.map(lambda examples: {'labels': examples['label']}, batched=True)\n",
    "train_dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "Loading cached processed dataset at /home/zhiling/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a/cache-a66ed2efed56e4c9.arrow\n",
      "Loading cached processed dataset at /home/zhiling/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a/cache-5f9134046515027e.arrow\n"
     ]
    }
   ],
   "source": [
    "dev_dataset = dev_dataset.map(lambda examples: {'labels': examples['label']}, batched=True)\n",
    "test_dataset = test_dataset.map(lambda examples: {'labels': examples['label']}, batched=True)"
   ]
  },
  {
   "source": [
    "## Loading the model and tokenizer, and preprocessing the data\n",
    "\n",
    "To keep the experiment fast, we pick a small model, `bert-tiny`, and load it together with its tokenizer."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "Some weights of the model checkpoint at prajjwal1/bert-tiny were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n",
      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "model_id = 'prajjwal1/bert-tiny'\n",
    "# note that we need to specify the number of classes for this task\n",
    "# we can directly use the metadata (num_classes) stored in the dataset\n",
    "model = AutoModelForSequenceClassification.from_pretrained(model_id, \n",
    "            num_labels=train_dataset.features[\"label\"].num_classes)\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_id)"
   ]
  },
  {
   "source": [
    "Tokenize the dataset with the BERT tokenizer, padding or truncating every sequence to 256 tokens."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "49b10b93f84e4e78bc0a2fee2dbff4d2"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "\n"
     ]
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "9764e9a80b73409eb152b936c99048ea"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "\n",
      "Loading cached processed dataset at /home/zhiling/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a/cache-0ece2b3c35fe711f.arrow\n"
     ]
    }
   ],
   "source": [
    "MAX_LENGTH = 256\n",
    "train_dataset = train_dataset.map(lambda e: tokenizer(e['text'], truncation=True, padding='max_length', max_length=MAX_LENGTH), batched=True)\n",
    "dev_dataset = dev_dataset.map(lambda e: tokenizer(e['text'], truncation=True, padding='max_length', max_length=MAX_LENGTH), batched=True)\n",
    "test_dataset = test_dataset.map(lambda e: tokenizer(e['text'], truncation=True, padding='max_length', max_length=MAX_LENGTH), batched=True)"
   ]
  },
  {
   "source": [
    "To feed the data into a PyTorch model for training, we also need to declare the output format and the columns to use."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])\n",
    "dev_dataset.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])\n",
    "test_dataset.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])"
   ]
  },
  {
   "source": [
    "Our training examples now look like this and can be fed straight into BERT for training."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "{'attention_mask': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None),\n",
       " 'input_ids': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None),\n",
       " 'label': ClassLabel(num_classes=4, names=['World', 'Sports', 'Business', 'Sci/Tech'], names_file=None, id=None),\n",
       " 'labels': Value(dtype='int64', id=None),\n",
       " 'text': Value(dtype='string', id=None),\n",
       " 'token_type_ids': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None)}"
      ]
     },
     "metadata": {},
     "execution_count": 16
    }
   ],
   "source": [
    "train_dataset.features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "{'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),\n",
       " 'input_ids': tensor([  101,  2813,  2358,  1012,  6468, 15020,  2067,  2046,  1996,  2304,\n",
       "          1006, 26665,  1007, 26665,  1011,  2460,  1011, 19041,  1010,  2813,\n",
       "          2395,  1005,  1055,  1040, 11101,  2989,  1032,  2316,  1997, 11087,\n",
       "          1011, 22330,  8713,  2015,  1010,  2024,  3773,  2665,  2153,  1012,\n",
       "           102,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "             0,     0,     0,     0,     0,     0]),\n",
       " 'labels': tensor(2),\n",
       " 'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])}"
      ]
     },
     "metadata": {},
     "execution_count": 15
    }
   ],
   "source": [
    "train_dataset[0]"
   ]
  },
  {
   "source": [
    "We can specify which validation metrics are reported during training."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_metrics(pred):\n",
    "    labels = pred.label_ids\n",
    "    preds = pred.predictions.argmax(-1)\n",
    "    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='macro')\n",
    "    acc = accuracy_score(labels, preds)\n",
    "    return {\n",
    "        'accuracy': acc,\n",
    "        'f1': f1,\n",
    "        'precision': precision,\n",
    "        'recall': recall\n",
    "    }"
   ]
  },
  {
   "source": [
    "Set the training arguments and train directly with the `Trainer`."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "<IPython.core.display.HTML object>",
      "text/html": "\n    <div>\n        <style>\n            /* Turns off some styling */\n            progress {\n                /* gets rid of default border in Firefox and Opera. */\n                border: none;\n                /* Needs to be in here for Safari polyfill so background images work as expected. */\n                background-size: auto;\n            }\n        </style>\n      \n      <progress value='2' max='1875' style='width:300px; height:20px; vertical-align: middle;'></progress>\n      [   2/1875 : < :, Epoch 0.00/3]\n    </div>\n    <table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: left;\">\n      <th>Epoch</th>\n      <th>Training Loss</th>\n      <th>Validation Loss</th>\n    </tr>\n  </thead>\n  <tbody>\n  </tbody>\n</table><p>"
     },
     "metadata": {}
    }
   ],
   "source": [
    "training_args = TrainingArguments(\n",
    "    output_dir='./results',          # output directory\n",
    "    learning_rate=3e-4,\n",
    "    num_train_epochs=3,              # total number of training epochs\n",
    "    per_device_train_batch_size=64,  # batch size per device during training\n",
    "    per_device_eval_batch_size=64,   # batch size for evaluation\n",
    "    logging_dir='./logs',            # directory for storing logs\n",
    "    logging_steps=100,\n",
    "    do_train=True,\n",
    "    do_eval=True,\n",
    "    no_cuda=False,\n",
    "    load_best_model_at_end=True,\n",
    "    # eval_steps=100,\n",
    "    evaluation_strategy=\"epoch\"\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,                         # the instantiated 🤗 Transformers model to be trained\n",
    "    args=training_args,                  # training arguments, defined above\n",
    "    train_dataset=train_dataset,         # training dataset\n",
    "    eval_dataset=dev_dataset,            # evaluation dataset\n",
    "    compute_metrics=compute_metrics\n",
    ")\n",
    "\n",
    "train_out = trainer.train()"
   ]
  },
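  {
   "source": [
    "## Predicting directly on text with pipeline\n",
    "\n",
    "`pipeline` can take the trained model and tokenizer directly and classify raw text, with no need to preprocess it yourself.\n",
    "\n",
    "First, we move the model back to the CPU for prediction."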
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = model.cpu()"
   ]
  },
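  {
   "source": [
    "We pass `sentiment-analysis` to indicate that this is a text classification task (sentiment analysis is a representative text classification task) and specify the model we just trained."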
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "classifier = pipeline('sentiment-analysis', model=model, tokenizer=tokenizer)"
   ]
  },
  {
   "source": [
    "We pick an example from the test set, which the model has not seen, and run a prediction on it."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "Using custom data configuration default\n",
      "Reusing dataset ag_news (/home/zhiling/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)\n"
     ]
    },
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "{'label': 2,\n",
       " 'text': \"Fears for T N pension after talks Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.\"}"
      ]
     },
     "metadata": {},
     "execution_count": 28
    }
   ],
   "source": [
    "test_examples = load_dataset(\"ag_news\", split=\"test[:10]\")\n",
    "test_examples[0]"
   ]
  },
  {
   "source": [
    "This text's class is 2 (`Business`). Let's see whether the model predicts it correctly."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "[{'label': 'LABEL_2', 'score': 0.9601152539253235}]"
      ]
     },
     "metadata": {},
     "execution_count": 22
    }
   ],
   "source": [
    "result = classifier(test_examples[0]['text'])\n",
    "result"
   ]
  },
  {
   "source": [
    "The prediction is correct!\n",
    "\n",
    "And with that, our tour of the Hugging Face stack is complete.\n",
    "\n",
    "The full code for this article is available here: [https://github.com/blmoistawinde/hello\\_world/blob/master/huggingface\\_classification.ipynb](https://github.com/blmoistawinde/hello_world/blob/master/huggingface_classification.ipynb)"
   ],
   "cell_type": "markdown",
   "metadata": {}
  }
 ]
}